#include "statsxx/machine_learning/activation_functions/ReLU.hpp"

// STL
#include <algorithm> // std::max()
#include <stdexcept> // std::runtime_error()

// jScience
#include "jrandnum.hpp" // rand_num_normal_Mersenne_twister()


// Default constructor: ReLU carries no state.
inline activation_function::ReLU::ReLU() = default;

// Default destructor: nothing to release.
inline activation_function::ReLU::~ReLU() = default;


// Rectified linear unit:
//     f(x) = max(0, x)
// i.e. the identity for positive inputs, zero otherwise.
inline double activation_function::ReLU::f(const double x)
{
    return (x > 0.) ? x : 0.;
}

// Derivative of the ReLU.
//
// NOTE: the derivative is mathematically undefined at x = 0; we adopt the
// standard subgradient convention (also used by most ML frameworks):
//     f'(x) = {0   for   x <= 0
//             {1   for   x >  0
//
// @param x  point at which to evaluate the derivative
// @return   0.0 for x <= 0, 1.0 for x > 0
inline double activation_function::ReLU::df(const double x)
{
    return (x > 0.) ? 1. : 0.;
}

// Inverse of the ReLU — intentionally not provided.
//
// ReLU is not injective: every input x <= 0 maps to 0, so an output of 0 has
// no unique preimage, and outputs x < 0 have no preimage at all. Any "inverse"
// would require an arbitrary convention, so this deliberately throws instead.
//
// @param x  (unused) value whose preimage would be sought
// @throws std::runtime_error always
inline double activation_function::ReLU::inv(const double x)
{
    throw std::runtime_error("error in activation_function::ReLU::inv(): not computed");
}